{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# This must be assigned BEFORE tensorflow is imported: the native logging\n",
    "# layer reads the variable when the library loads, so setting it after the\n",
    "# import does not silence the start-up messages.\n",
    "# '0' shows everything, '1' hides INFO, '2' hides INFO and WARNING.\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers, optimizers, datasets\n",
    "\n",
    "def mnist_dataset():\n",
    "    '''Load MNIST and rescale the images of both splits by 1/255.\n",
    "\n",
    "    Returns:\n",
    "        ((train_images, train_labels), (test_images, test_labels)), where\n",
    "        both image arrays have been divided by 255.0 and the labels are\n",
    "        returned exactly as datasets.mnist.load_data() provides them.\n",
    "    '''\n",
    "    train_split, test_split = datasets.mnist.load_data()\n",
    "    train_images, train_labels = train_split\n",
    "    test_images, test_labels = test_split\n",
    "\n",
    "    # normalize: scale the raw pixel values by 1/255\n",
    "    train_images = train_images / 255.0\n",
    "    test_images = test_images / 255.0\n",
    "\n",
    "    return (train_images, train_labels), (test_images, test_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')]\n"
     ]
    }
   ],
   "source": [
    "# Scratch demo of zip(): it pairs up the items of two sequences position\n",
    "# by position, yielding one tuple per index.\n",
    "print([*zip([1, 2, 3, 4], ['a', 'b', 'c', 'd'])])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 建立模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class myModel:\n",
    "    '''Two-layer fully connected classifier that maps images to class logits.\n",
    "\n",
    "    The parameters are stored as the tf.Variable attributes W1, b1, W2 and\n",
    "    b2, which are the names train_one_step reads when it gathers the\n",
    "    variables to update.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, input_dim=28 * 28, hidden_dim=100, num_classes=10):\n",
    "        '''Declare the model parameters.\n",
    "\n",
    "        Args:\n",
    "            input_dim: number of values per example after flattening.\n",
    "            hidden_dim: width of the hidden layer.\n",
    "            num_classes: number of logits produced per example.\n",
    "        '''\n",
    "        self.input_dim = input_dim\n",
    "        # Small random weights break the symmetry between hidden units;\n",
    "        # biases can safely start at zero.\n",
    "        self.W1 = tf.Variable(\n",
    "            tf.random.normal([input_dim, hidden_dim], stddev=0.1), name='W1')\n",
    "        self.b1 = tf.Variable(tf.zeros([hidden_dim]), name='b1')\n",
    "        self.W2 = tf.Variable(\n",
    "            tf.random.normal([hidden_dim, num_classes], stddev=0.1), name='W2')\n",
    "        self.b2 = tf.Variable(tf.zeros([num_classes]), name='b2')\n",
    "\n",
    "    def __call__(self, x):\n",
    "        '''Run the forward pass and return unnormalized logits.\n",
    "\n",
    "        Args:\n",
    "            x: float tensor whose per-example values flatten to input_dim\n",
    "                entries, e.g. shape (batch, 28, 28).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, num_classes) with raw class scores; no\n",
    "            softmax is applied here because the loss expects logits.\n",
    "        '''\n",
    "        flat = tf.reshape(x, [-1, self.input_dim])\n",
    "        hidden = tf.nn.relu(tf.matmul(flat, self.W1) + self.b1)\n",
    "        logits = tf.matmul(hidden, self.W2) + self.b2\n",
    "        return logits\n",
    "        \n",
    "# Notebook-level model instance shared by the training and evaluation cells.\n",
    "model = myModel()\n",
    "\n",
    "# NOTE(review): this optimizer is handed to train_one_step, but that function\n",
    "# applies its own fixed-step gradient update and never calls the optimizer.\n",
    "optimizer = optimizers.Adam()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 计算 loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tf.function\n",
    "def compute_loss(logits, labels):\n",
    "    '''Return the mean sparse softmax cross-entropy of a batch.\n",
    "\n",
    "    Args:\n",
    "        logits: unnormalized class scores, one row per example.\n",
    "        labels: integer class index for each example.\n",
    "\n",
    "    Returns:\n",
    "        Scalar tensor: the per-example losses averaged with tf.reduce_mean.\n",
    "    '''\n",
    "    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
    "        labels=labels, logits=logits)\n",
    "    return tf.reduce_mean(per_example)\n",
    "\n",
    "@tf.function\n",
    "def compute_accuracy(logits, labels):\n",
    "    '''Return the fraction of examples whose arg-max logit equals the label.\n",
    "\n",
    "    Args:\n",
    "        logits: class scores, one row per example; the prediction is the\n",
    "            index of the largest entry along axis 1.\n",
    "        labels: integer class indices; compared with tf.equal, so they must\n",
    "            share the dtype tf.argmax produces.\n",
    "\n",
    "    Returns:\n",
    "        Scalar float32 tensor with the mean of the per-example matches.\n",
    "    '''\n",
    "    predicted = tf.argmax(logits, axis=1)\n",
    "    hits = tf.cast(tf.equal(predicted, labels), tf.float32)\n",
    "    return tf.reduce_mean(hits)\n",
    "\n",
    "@tf.function\n",
    "def train_one_step(model, optimizer, x, y):\n",
    "    '''Apply one gradient-descent update on (x, y) and report batch metrics.\n",
    "\n",
    "    Args:\n",
    "        model: callable returning logits that exposes the variables W1, W2,\n",
    "            b1 and b2.\n",
    "        optimizer: accepted but not used; the update below is a hand-written\n",
    "            step of 0.01 times the gradient.\n",
    "        x: input batch fed to the model.\n",
    "        y: integer labels for the batch.\n",
    "\n",
    "    Returns:\n",
    "        (loss, accuracy) as scalar tensors, both derived from the logits\n",
    "        computed before the variables were updated.\n",
    "    '''\n",
    "    with tf.GradientTape() as tape:\n",
    "        logits = model(x)\n",
    "        loss = compute_loss(logits, y)\n",
    "\n",
    "    # Differentiate the loss w.r.t. every parameter, then step each variable\n",
    "    # in place against its gradient.\n",
    "    params = [model.W1, model.W2, model.b1, model.b2]\n",
    "    step_size = 0.01\n",
    "    for grad, param in zip(tape.gradient(loss, params), params):\n",
    "        param.assign_sub(step_size * grad)\n",
    "\n",
    "    # both returned values are scalar tensors\n",
    "    return loss, compute_accuracy(logits, y)\n",
    "\n",
    "@tf.function\n",
    "def test(model, x, y):\n",
    "    '''Evaluate the model on (x, y) without updating any variable.\n",
    "\n",
    "    Args:\n",
    "        model: callable returning logits for x.\n",
    "        x: input batch.\n",
    "        y: integer labels for the batch.\n",
    "\n",
    "    Returns:\n",
    "        (loss, accuracy) scalar tensors from compute_loss and\n",
    "        compute_accuracy on a single forward pass.\n",
    "    '''\n",
    "    logits = model(x)\n",
    "    return compute_loss(logits, y), compute_accuracy(logits, y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 实际训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0 : loss 0.7526799 ; accuracy 0.8300833\n",
      "epoch 1 : loss 0.7520259 ; accuracy 0.83023334\n",
      "epoch 2 : loss 0.7513737 ; accuracy 0.83031666\n",
      "epoch 3 : loss 0.75072306 ; accuracy 0.8305\n",
      "epoch 4 : loss 0.7500742 ; accuracy 0.8305167\n",
      "epoch 5 : loss 0.74942684 ; accuracy 0.83063334\n",
      "epoch 6 : loss 0.748781 ; accuracy 0.83071667\n",
      "epoch 7 : loss 0.74813706 ; accuracy 0.8308\n",
      "epoch 8 : loss 0.74749464 ; accuracy 0.8308667\n",
      "epoch 9 : loss 0.74685377 ; accuracy 0.83095\n",
      "epoch 10 : loss 0.74621457 ; accuracy 0.8311167\n",
      "epoch 11 : loss 0.74557704 ; accuracy 0.8312167\n",
      "epoch 12 : loss 0.7449409 ; accuracy 0.83133334\n",
      "epoch 13 : loss 0.7443065 ; accuracy 0.8314667\n",
      "epoch 14 : loss 0.7436737 ; accuracy 0.8315833\n",
      "epoch 15 : loss 0.74304247 ; accuracy 0.83165\n",
      "epoch 16 : loss 0.74241275 ; accuracy 0.8318167\n",
      "epoch 17 : loss 0.74178463 ; accuracy 0.8319333\n",
      "epoch 18 : loss 0.74115807 ; accuracy 0.83208334\n",
      "epoch 19 : loss 0.7405333 ; accuracy 0.83218336\n",
      "epoch 20 : loss 0.73990965 ; accuracy 0.8323333\n",
      "epoch 21 : loss 0.7392879 ; accuracy 0.83241665\n",
      "epoch 22 : loss 0.7386676 ; accuracy 0.8325833\n",
      "epoch 23 : loss 0.7380488 ; accuracy 0.8327\n",
      "epoch 24 : loss 0.73743147 ; accuracy 0.83278334\n",
      "epoch 25 : loss 0.7368157 ; accuracy 0.8329167\n",
      "epoch 26 : loss 0.7362015 ; accuracy 0.8330167\n",
      "epoch 27 : loss 0.7355887 ; accuracy 0.83318335\n",
      "epoch 28 : loss 0.7349776 ; accuracy 0.8332\n",
      "epoch 29 : loss 0.73436785 ; accuracy 0.8332833\n",
      "epoch 30 : loss 0.73375976 ; accuracy 0.8333333\n",
      "epoch 31 : loss 0.733153 ; accuracy 0.83343333\n",
      "epoch 32 : loss 0.73254776 ; accuracy 0.8335\n",
      "epoch 33 : loss 0.731944 ; accuracy 0.83351666\n",
      "epoch 34 : loss 0.7313418 ; accuracy 0.83365\n",
      "epoch 35 : loss 0.73074096 ; accuracy 0.83376664\n",
      "epoch 36 : loss 0.7301418 ; accuracy 0.83391666\n",
      "epoch 37 : loss 0.72954386 ; accuracy 0.83411664\n",
      "epoch 38 : loss 0.7289475 ; accuracy 0.83416665\n",
      "epoch 39 : loss 0.7283526 ; accuracy 0.83421665\n",
      "epoch 40 : loss 0.727759 ; accuracy 0.83421665\n",
      "epoch 41 : loss 0.72716707 ; accuracy 0.8343833\n",
      "epoch 42 : loss 0.7265763 ; accuracy 0.8344833\n",
      "epoch 43 : loss 0.72598726 ; accuracy 0.83461666\n",
      "epoch 44 : loss 0.7253995 ; accuracy 0.83468336\n",
      "epoch 45 : loss 0.7248133 ; accuracy 0.83488333\n",
      "epoch 46 : loss 0.72422826 ; accuracy 0.83493334\n",
      "epoch 47 : loss 0.7236447 ; accuracy 0.83501667\n",
      "epoch 48 : loss 0.72306263 ; accuracy 0.8351167\n",
      "epoch 49 : loss 0.7224819 ; accuracy 0.83515\n",
      "test loss 0.69837373 ; accuracy 0.846\n"
     ]
    }
   ],
   "source": [
    "train_data, test_data = mnist_dataset()\n",
    "\n",
    "# Convert the training split to tensors once. The arrays never change\n",
    "# between iterations, so rebuilding them inside the loop only repeats the\n",
    "# same host-side copy on every step.\n",
    "train_x = tf.constant(train_data[0], dtype=tf.float32)\n",
    "train_y = tf.constant(train_data[1], dtype=tf.int64)\n",
    "\n",
    "# Each iteration is a single full-batch update over the whole training split.\n",
    "for epoch in range(50):\n",
    "    loss, accuracy = train_one_step(model, optimizer, train_x, train_y)\n",
    "    print('epoch', epoch, ': loss', loss.numpy(), '; accuracy', accuracy.numpy())\n",
    "\n",
    "# Final evaluation on the test split.\n",
    "loss, accuracy = test(model,\n",
    "                      tf.constant(test_data[0], dtype=tf.float32),\n",
    "                      tf.constant(test_data[1], dtype=tf.int64))\n",
    "\n",
    "print('test loss', loss.numpy(), '; accuracy', accuracy.numpy())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
